{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cEMg3TK0eVPj"
   },
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License.\n",
    "\n",
    "# 3D regression example based on DenseNet\n",
    "\n",
    "This tutorial shows an example of a 3D regression task based on DenseNet and array format transforms.\n",
    "\n",
    "Here, the task is to predict the ages of subjects from their MR images.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/3d_regression/densenet_training_array.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "P2pXpPRYeVPl"
   },
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-06-23T00:20:33.913058Z",
     "start_time": "2024-06-23T00:20:31.451820Z"
    },
    "id": "dlZS78A8eVPl"
   },
   "outputs": [],
   "source": [
    "# Install MONAI (with the nibabel and tqdm extras) only when it is not already importable.\n",
    "# `python -m pip` is used so the install targets the same interpreter that ran the import check.\n",
    "!python -c \"import monai\" || python -m pip install -q \"monai-weekly[nibabel, tqdm]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "l82hea4heVPm"
   },
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-06-23T00:20:35.515047Z",
     "start_time": "2024-06-23T00:20:33.921859Z"
    },
    "id": "0mu2D_19eVPm"
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "import os\n",
    "import sys\n",
    "import shutil\n",
    "import tempfile\n",
    "\n",
    "import torch\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import numpy as np\n",
    "\n",
    "import monai\n",
    "from monai.apps import download_and_extract\n",
    "from monai.config import print_config\n",
    "from monai.data import DataLoader, ImageDataset\n",
    "from monai.transforms import (\n",
    "    EnsureChannelFirst,\n",
    "    Compose,\n",
    "    RandRotate90,\n",
    "    Resize,\n",
    "    ScaleIntensity,\n",
    ")\n",
    "from monai.networks.nets import Regressor\n",
    "\n",
    "pin_memory = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QVfv5JUVR892"
   },
   "source": [
    "## Setup data directory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2024-06-23T00:20:35.518443Z"
    },
    "id": "XnU_pzCbeVPn"
   },
   "outputs": [],
   "source": [
    "# Set data directory\n",
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "if directory is not None:\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
    "print(root_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-06-23T00:20:35.524079Z",
     "start_time": "2024-06-23T00:20:35.521258Z"
    },
    "id": "2wxD3j2feVPn"
   },
   "outputs": [],
   "source": [
    "# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/\n",
    "# File names of the T1 volumes used here; entry i is paired with ages[i] below.\n",
    "image_names = [\n",
    "    \"IXI314-IOP-0889-T1.nii.gz\",\n",
    "    \"IXI249-Guys-1072-T1.nii.gz\",\n",
    "    \"IXI609-HH-2600-T1.nii.gz\",\n",
    "    \"IXI173-HH-1590-T1.nii.gz\",\n",
    "    \"IXI020-Guys-0700-T1.nii.gz\",\n",
    "    \"IXI342-Guys-0909-T1.nii.gz\",\n",
    "    \"IXI134-Guys-0780-T1.nii.gz\",\n",
    "    \"IXI577-HH-2661-T1.nii.gz\",\n",
    "    \"IXI066-Guys-0731-T1.nii.gz\",\n",
    "    \"IXI130-HH-1528-T1.nii.gz\",\n",
    "    \"IXI607-Guys-1097-T1.nii.gz\",\n",
    "    \"IXI175-HH-1570-T1.nii.gz\",\n",
    "    \"IXI385-HH-2078-T1.nii.gz\",\n",
    "    \"IXI344-Guys-0905-T1.nii.gz\",\n",
    "    \"IXI409-Guys-0960-T1.nii.gz\",\n",
    "    \"IXI584-Guys-1129-T1.nii.gz\",\n",
    "    \"IXI253-HH-1694-T1.nii.gz\",\n",
    "    \"IXI092-HH-1436-T1.nii.gz\",\n",
    "    \"IXI574-IOP-1156-T1.nii.gz\",\n",
    "    \"IXI585-Guys-1130-T1.nii.gz\",\n",
    "]\n",
    "# Full paths under <root_dir>/ixi, in the same order as image_names.\n",
    "images = [os.sep.join([root_dir, \"ixi\", name]) for name in image_names]\n",
    "\n",
    "# ages of subjects, one value per image in the same order\n",
    "ages = np.array(\n",
    "    [\n",
    "        45.86, 68.27, 29.0, 29.57, 39.47,\n",
    "        48.68, 47.35, 64.19, 46.17, 38.77,\n",
    "        83.81, 72.27, 64.65, 62.09, 70.95,\n",
    "        41.33, 24.0, 33.24, 50.57, 28.12,\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2024-06-23T00:20:35.525723Z"
    },
    "id": "1zcWs3THeVPn",
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "# Download and extract the IXI T1 archive only when the first expected image is not on disk yet.\n",
    "if not os.path.isfile(images[0]):\n",
    "    dataset_dir = os.path.join(root_dir, \"ixi\")\n",
    "    # The archive is stored next to the extraction directory as <root_dir>/ixi.tar.\n",
    "    tarfile_name = dataset_dir + \".tar\"\n",
    "\n",
    "    resource = \"https://developer.download.nvidia.com/assets/Clara/monai/tutorials/IXI-T1.tar\"\n",
    "    md5 = \"34901a0593b41dd19c1a1f746eac2d58\"\n",
    "\n",
    "    download_and_extract(resource, tarfile_name, dataset_dir, md5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "P3-MIx_lSNKF"
   },
   "source": [
    "## Create data loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yVvt3XtHeVPn",
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "# Define transforms\n",
    "# Training pipeline: shared preprocessing followed by a random 90-degree rotation.\n",
    "train_transforms = Compose(\n",
    "    [\n",
    "        ScaleIntensity(),\n",
    "        EnsureChannelFirst(),\n",
    "        Resize((96, 96, 96)),\n",
    "        RandRotate90(),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Validation pipeline: the same preprocessing without the random rotation.\n",
    "val_transforms = Compose(\n",
    "    [\n",
    "        ScaleIntensity(),\n",
    "        EnsureChannelFirst(),\n",
    "        Resize((96, 96, 96)),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Options shared by every loader in this notebook.\n",
    "loader_kwargs = {\"num_workers\": 2, \"pin_memory\": pin_memory}\n",
    "\n",
    "# Define nifti dataset, data loader\n",
    "# Sanity check over all images: fetch the first batch and print its type and shapes.\n",
    "check_ds = ImageDataset(image_files=images, labels=ages, transform=train_transforms)\n",
    "check_loader = DataLoader(check_ds, batch_size=3, **loader_kwargs)\n",
    "\n",
    "im, label = monai.utils.misc.first(check_loader)\n",
    "print(type(im), im.shape, label, label.shape)\n",
    "\n",
    "# create a training data loader (first 10 subjects, shuffled)\n",
    "train_ds = ImageDataset(image_files=images[:10], labels=ages[:10], transform=train_transforms)\n",
    "train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, **loader_kwargs)\n",
    "\n",
    "# create a validation data loader (last 10 subjects, fixed order)\n",
    "val_ds = ImageDataset(image_files=images[-10:], labels=ages[-10:], transform=val_transforms)\n",
    "val_loader = DataLoader(val_ds, batch_size=2, **loader_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YgH8asdqSaiT"
   },
   "source": [
    "## Create model and train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "peVze9d7eVPo",
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "# Build the regression network and move it to the same device the batches are sent to.\n",
    "model = Regressor(in_shape=[1, 96, 96, 96], out_shape=1, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2))\n",
    "model.to(device)\n",
    "# It is important that we use nn.MSELoss for regression.\n",
    "loss_function = torch.nn.MSELoss()\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), 1e-4)\n",
    "\n",
    "# start a typical PyTorch training\n",
    "val_interval = 2  # run validation every `val_interval` epochs\n",
    "max_epochs = 5\n",
    "epoch_loss_values = []  # mean training loss of every epoch\n",
    "metric_values = []  # validation RMSE of every validated epoch\n",
    "writer = SummaryWriter()\n",
    "\n",
    "# Number of training steps per epoch; it does not change, so compute it once.\n",
    "epoch_len = len(train_loader)\n",
    "\n",
    "# Best validation RMSE seen so far and the epoch it was reached at. The epoch is\n",
    "# initialised so the summary below still works when validation never runs\n",
    "# (e.g. max_epochs < val_interval).\n",
    "lowest_rmse = sys.float_info.max\n",
    "lowest_rmse_epoch = -1\n",
    "for epoch in range(max_epochs):\n",
    "    print(\"-\" * 10)\n",
    "    print(f\"epoch {epoch + 1}/{max_epochs}\")\n",
    "    model.train()\n",
    "    epoch_loss = 0\n",
    "    step = 0\n",
    "\n",
    "    for batch_data in train_loader:\n",
    "        step += 1\n",
    "        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(inputs)\n",
    "        # Give the labels the same shape as the model output before computing the loss;\n",
    "        # otherwise MSELoss broadcasts the two tensors against each other and averages\n",
    "        # over every (prediction, label) pair in the batch instead of matching pairs.\n",
    "        loss = loss_function(outputs, labels.float().reshape(outputs.shape))\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        epoch_loss += loss.item()\n",
    "        print(f\"{step}/{epoch_len}, train_loss: {loss.item():.4f}\")\n",
    "        writer.add_scalar(\"train_loss\", loss.item(), epoch_len * epoch + step)\n",
    "\n",
    "    epoch_loss /= step\n",
    "    epoch_loss_values.append(epoch_loss)\n",
    "    print(f\"epoch {epoch + 1} average loss: {epoch_loss:.4f}\")\n",
    "\n",
    "    if (epoch + 1) % val_interval == 0:\n",
    "        model.eval()\n",
    "        all_labels = []\n",
    "        all_val_outputs = []\n",
    "        # No gradients are needed anywhere in the validation pass.\n",
    "        with torch.no_grad():\n",
    "            for val_data in val_loader:\n",
    "                val_images, val_labels = val_data[0].to(device), val_data[1]\n",
    "                val_outputs = model(val_images)\n",
    "                all_labels.extend(val_labels.cpu().detach().numpy())\n",
    "                # Flatten the per-sample outputs into a flat list of predictions.\n",
    "                all_val_outputs.extend(val_outputs.cpu().detach().numpy().reshape(-1))\n",
    "\n",
    "        mse = np.square(np.subtract(all_labels, all_val_outputs)).mean()\n",
    "        rmse = np.sqrt(mse)\n",
    "        metric_values.append(rmse)\n",
    "\n",
    "        # Keep the weights of the epoch with the lowest validation RMSE.\n",
    "        if rmse < lowest_rmse:\n",
    "            lowest_rmse = rmse\n",
    "            lowest_rmse_epoch = epoch + 1\n",
    "            torch.save(model.state_dict(), \"best_metric_model_classification3d_array.pth\")\n",
    "            print(\"saved new best metric model\")\n",
    "\n",
    "        print(f\"Current epoch: {epoch+1} current RMSE: {rmse:.4f} \")\n",
    "        print(f\"Best RMSE: {lowest_rmse:.4f} at epoch {lowest_rmse_epoch}\")\n",
    "        writer.add_scalar(\"val_rmse\", rmse, epoch + 1)\n",
    "\n",
    "print(f\"Training completed, lowest_rmse: {lowest_rmse:.4f} at epoch: {lowest_rmse_epoch}\")\n",
    "writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TrwW4TLneVPp"
   },
   "source": [
    "## Cleanup data directory\n",
    "\n",
    "Remove the data directory if a temporary one was used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "66kVvM1eeVPp",
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "# Delete the data directory only when this notebook created it as a temporary directory,\n",
    "# i.e. when MONAI_DATA_DIRECTORY was not set; a user-provided directory is left in place.\n",
    "created_temp_dir = directory is None\n",
    "if created_temp_dir:\n",
    "    shutil.rmtree(root_dir)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
